from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
import torch
from tqdm import tqdm
import numpy as np
from mri_forward import forward, noiser
import os
from PIL import Image
import torchvision.transforms as transforms
from utils import save_image, calculate_metric, save_coordinate
import argparse

def grad_likelihood(yhat, y, x_prev):
    """Return the normalized data-consistency gradient used for DPS guidance.

    Computes ``d ||yhat - y||^2 / d x_prev`` and divides it by ``||yhat - y||``
    (mathematically ``2 * d ||yhat - y|| / d x_prev``).

    Args:
        yhat: Predicted measurement, computed from ``x_prev`` through
            differentiable operations.
        y: Observed measurement; must be subtractable from ``yhat``.
        x_prev: Tensor with ``requires_grad=True`` that ``yhat`` depends on.

    Returns:
        A tensor shaped like ``x_prev`` that carries no autograd history.
        It is all zeros when the residual is exactly zero, instead of the
        NaN that ``0 / 0`` would produce.
    """
    norm = torch.linalg.norm(yhat - y)
    grad = torch.autograd.grad(outputs=norm ** 2, inputs=x_prev)[0]
    # Detach the normalizer so the result is a plain tensor. If `norm` stayed
    # attached, every latent update built from this gradient would keep a
    # reference to the current step's graph. The clamp only affects an
    # exactly-zero residual.
    denom = norm.detach().clamp_min(torch.finfo(norm.dtype).tiny)
    return grad / denom

def _build_parser():
    """Create the command-line parser for the Stable Diffusion + DPS script."""
    cli = argparse.ArgumentParser(description="Stable Diffusion with DPS")
    # (flag, keyword arguments for add_argument), in the order they are registered.
    option_table = [
        ("--root_dir", {"type": str, "default": "./result",
                        "help": "Root directory for output"}),
        ("--sigma", {"type": float, "default": 0.05,
                     "help": "Noise level for measurement"}),
        ("--mask_type", {"type": str,
                         "choices": ["uniform", "gaussian", "poisson", "hexagonal"],
                         "default": "uniform", "help": "Type of mask"}),
        ("--image_type", {"type": str, "choices": ["brain", "knee", "celebA"],
                          "default": "brain", "help": "Type of image"}),
        ("--acceleration", {"type": int, "choices": [5, 10, 15, 20],
                            "default": 5, "help": "Acceleration factor"}),
        ("--zeta", {"type": float, "default": 1.4,
                    "help": "Step size for gradient descent"}),
        ("--num_inference_step", {"type": int, "default": 1000,
                                  "help": "Number of inference steps"}),
        ("--guidance_scale", {"type": float, "default": 7.5,
                              "help": "Guidance scale for classifier-free guidance"}),
        ("--save_image_step", {"type": int, "default": 100,
                               "help": "Step interval for saving images"}),
    ]
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli


parser = _build_parser()
args = parser.parse_args()

root_dir = args.root_dir
os.makedirs(root_dir, exist_ok=True)

# Define measurement parameters
dtype = torch.float32
sigma = args.sigma
image_size = (512, 512)
mask_type = args.mask_type
# Only the hexagonal mask takes the `is_nufft` path (passed to `forward` and
# used to pick `save_coordinate` when saving the mask).
is_nufft = mask_type == "hexagonal"
acceleration = args.acceleration
mask_path = f"./masks/mask_{mask_type}_acc{acceleration}_c32.npy"
image_type = args.image_type

# One fixed test image per dataset; keys mirror the --image_type choices.
IMAGE_PATHS = {
    "brain": "./dataset/fastmri_brain/file_brain_AXT2_200_2000056_slice00.png",
    "knee": "./dataset/fastmri_knee/file1000182_slice30.png",
    "celebA": "./dataset/celebA/182340.png",
}
image_path = IMAGE_PATHS[image_type]

zeta = args.zeta

num_images_per_prompt = 1
num_inference_step = args.num_inference_step
guidance_scale = args.guidance_scale
do_classifier_free_guidance = guidance_scale > 1.0
save_image_step = args.save_image_step

device = 'cuda' if torch.cuda.is_available() else 'cpu'

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    # Map [0, 1] to [-1, 1]. Normalize's signature is (mean, std, inplace),
    # so mean and std are given explicitly as single-channel tuples.
    transforms.Normalize((0.5,), (0.5,)),
    lambda x: x.to(dtype),
])


output_dir = os.path.join(root_dir, f"{image_type}_{mask_type}/acc{acceleration}/sigma{sigma}/zeta{zeta}_steps{num_inference_step}")
os.makedirs(output_dir, exist_ok=True)

# Load mask and image
mask = torch.tensor(np.load(mask_path), dtype=dtype)
if is_nufft:
    # NUFFT (hexagonal) masks are saved with save_coordinate instead of save_image.
    save_coordinate(mask, os.path.join(output_dir, "mask.png"))
else:
    save_image(mask, os.path.join(output_dir, "mask.png"))

# Close the source file once the grayscale copy has been made.
with Image.open(image_path) as pil_image:
    img = pil_image.convert('L')
original_img = transform(img)
save_image(original_img.squeeze(0), os.path.join(output_dir, "ground truth.png"))

# compute measurement
measurement = forward(original_img, mask, is_nufft, acceleration, dtype)
measurement = noiser(measurement, sigma)

save_image(measurement.squeeze(0), os.path.join(output_dir, "measurement.png"))

# A single list of settings feeds both result.txt and the console, so the two
# stay in sync. The guidance scale must be recorded in the file because it is
# not part of `output_dir`.
hyperparameter_lines = [
    f"Sigma: {sigma}",
    f"Mask Path: {mask_path}",
    f"Image Path: {image_path}",
    f"Zeta: {zeta}",
    f"Num Inference Step: {num_inference_step}",
    f"Guidance Scale: {guidance_scale}",
    f"Save Image Step: {save_image_step}",
]

# Create a file to save the results
result_file_path = os.path.join(output_dir, "result.txt")
with open(result_file_path, "w") as result_file:
    result_file.write("Hyperparameters:\n")
    result_file.writelines(f"{line}\n" for line in hyperparameter_lines)

print("Hyperparameters:")
for line in hyperparameter_lines:
    print(line)

###############################################################################################################################################################
model_id = "stabilityai/stable-diffusion-2-1-base"

scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler", torch_dtype=dtype)
pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=dtype)
pipe = pipe.to(device)
# Guidance only needs gradients w.r.t. the latents, never w.r.t. the weights.
# Freezing the weights stops autograd from tracking the parameters, and means
# the text-encoder and VAE calls made purely for inference build no graph.
for frozen_module in (pipe.unet, pipe.vae, pipe.text_encoder):
    frozen_module.requires_grad_(False)
unet = pipe.unet

# get prompt text embeddings
prompt = ""
text_input = pipe.tokenizer(
    prompt,
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
text_embeddings = pipe.text_encoder(text_input.input_ids.to(device))[0]
# duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
    max_length = text_input.input_ids.shape[-1]
    uncond_input = pipe.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(device))[0]
    # duplicate unconditional embeddings for each generation per prompt
    uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

    # Stack [uncond, text] into one batch so the unet handles both
    # classifier-free-guidance branches in a single forward pass.
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

latents_shape = (num_images_per_prompt, unet.config.in_channels, unet.config.sample_size, unet.config.sample_size)
latents_dtype = text_embeddings.dtype

# torch.manual_seed(0) reseeds the global RNG and returns the default generator.
latents = torch.randn(latents_shape, generator=torch.manual_seed(0), dtype=latents_dtype).to(device)

# set timesteps
scheduler.set_timesteps(num_inference_step)
timesteps_tensor = scheduler.timesteps.to(device)

# scale the initial noise by the standard deviation required by the scheduler
latents = latents * scheduler.init_noise_sigma

mask = mask.to(device)
measurement = measurement.to(device)
measurement = measurement.unsqueeze(0).repeat(num_images_per_prompt, 1, 1, 1)

for i, t in enumerate(pipe.progress_bar(timesteps_tensor)):
    with torch.enable_grad():
        # 1. predict noise model_output
        # `latents` is a detached leaf here, so gradients can be taken w.r.t. it.
        latents = latents.requires_grad_(True)

        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # perform classifier free guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # 2. compute previous image x'_{t-1} and original prediction x0_{t}
        scheduler_out = scheduler.step(noise_pred, t, latents)
        latents_pred, origi_pred = scheduler_out.prev_sample, scheduler_out.pred_original_sample

        # 3. compute y'_t = f(x0_{t})
        # scale and decode the image latents with vae
        origi_pred = origi_pred / pipe.vae.config.scaling_factor
        origi_pred = origi_pred.type(dtype)
        origi_image = pipe.vae.decode(origi_pred).sample
        forward_origi_image = forward(origi_image, mask, is_nufft, acceleration, dtype)

        # 4. gradient step on the data-consistency loss d(y, y'_t)
        grad = grad_likelihood(forward_origi_image, measurement, latents)
        # Detach so the next iteration starts from a fresh leaf. Otherwise each
        # step's latents keep the previous step's unet/VAE graph alive, and
        # memory grows with the number of steps.
        latents = (latents_pred - zeta * grad).detach()

    # 5. save image (preview decode only; no gradients needed)
    if i % save_image_step == 0:
        with torch.no_grad():
            image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image[0].cpu()
        image = transforms.functional.rgb_to_grayscale(image, num_output_channels=1)
        save_image(image.squeeze(0), os.path.join(output_dir, "dps_test_step{}_idx0.png".format(t)))
        # NOTE(review): `original_img` was normalized to [-1, 1] by `transform`,
        # while `image` is in [0, 1]. Confirm calculate_metric rescales both.
        ssim_value, psnr_value, lpips_value = calculate_metric(original_img, image)
        with open(result_file_path, "a") as result_file:
            result_file.write(f"Step {t}: SSIM {ssim_value:.4f}, PSNR {psnr_value:.2f}, LPIPS {lpips_value:.4f}\n")
        
# scale and decode the image latents with vae
latents = latents.detach()
latents = latents / pipe.vae.config.scaling_factor
# Pure inference: decode without recording an autograd graph.
with torch.no_grad():
    image = pipe.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu()
image = transforms.functional.rgb_to_grayscale(image, num_output_channels=1)
# NOTE(review): `original_img` is in [-1, 1] and `image` in [0, 1]; confirm
# calculate_metric accounts for the different ranges.
ssim_value, psnr_value, lpips_value = calculate_metric(original_img, image[0])
# `t` is the last timestep from the sampling loop, so the final metrics line is
# labelled with that timestep.
with open(result_file_path, "a") as result_file:
    result_file.write(f"Step {t}: SSIM {ssim_value:.4f}, PSNR {psnr_value:.2f}, LPIPS {lpips_value:.4f}\n")

for i in range(num_images_per_prompt):
    save_image(image[i].squeeze(0), os.path.join(output_dir, "dps_test_final_idx{}.png".format(i)))
